{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b73540d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44ce8697",
   "metadata": {},
   "source": [
    "# OpenAI Supervised Fine-Tuning using Direct Preference Optimization (DPO)\n",
    "\n",
    "This recipe allows TensorZero users to fine-tune OpenAI models using Direct Preference Optimization (DPO) and their own data. Since TensorZero automatically logs all inferences and feedback, it is straightforward to fine-tune a model using your own data and any prompt you want.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e30dcbe",
   "metadata": {},
   "source": [
    "To get started:\n",
    "\n",
    "- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable. For example: `TENSORZERO_CLICKHOUSE_URL`=`\"http://chuser:chpassword@localhost:8123/tensorzero\"`\n",
    "- Set the `OPENAI_API_KEY` environment variable.\n",
    "- Update the following parameters:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3bf0acb",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG_PATH = \"../../../ui/fixtures/config/tensorzero.toml\"\n",
    "\n",
    "FUNCTION_NAME = \"extract_entities\"\n",
    "\n",
    "# The name of the variant to use to grab the templates used for fine-tuning\n",
    "TEMPLATE_VARIANT_NAME = \"gpt_4o_mini\"  # It's OK that this variant uses a different model than the one we're fine-tuning\n",
    "\n",
    "# Fraction of the data to use for validation\n",
    "VAL_FRACTION = 0.2\n",
    "\n",
    "# Maximum number of samples to use for fine-tuning\n",
    "MAX_SAMPLES = 1000\n",
    "\n",
    "#  Model \"gpt-4o-2024-08-06\" is to our knowledge the only base model supported for this method.\n",
    "#  You can can use the base model as below or fine-tunes derived from it for this recipe.\n",
    "MODEL_NAME = \"gpt-4o-2024-08-06\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "365a71f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import tempfile\n",
    "import time\n",
    "from pprint import pprint\n",
    "from typing import Any, Dict, List\n",
    "\n",
    "import openai\n",
    "import toml\n",
    "from IPython.display import clear_output\n",
    "from tensorzero import ContentBlock, RenderedSample, TensorZeroGateway"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc712df7",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert \"TENSORZERO_CLICKHOUSE_URL\" in os.environ, \"TENSORZERO_CLICKHOUSE_URL environment variable not set\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "152d13d9",
   "metadata": {},
   "source": [
    "Initialize the TensorZero client\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4471a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = TensorZeroGateway.build_embedded(clickhouse_url=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"], config_file=CONFIG_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "835e3e38",
   "metadata": {},
   "outputs": [],
   "source": [
    "inferences = t0.experimental_list_inferences(\n",
    "    function_name=FUNCTION_NAME,\n",
    "    output_source=\"demonstration\",  # Since we're using DPO we need pairwise data so we must use demonstrations\n",
    "    limit=MAX_SAMPLES,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52e576c7",
   "metadata": {},
   "source": [
    "OpenAI requires the fine-tuning data (for DPO) to be structured in this [format](https://platform.openai.com/docs/guides/fine-tuning#preference)\n",
    "\n",
    "```\n",
    "{\n",
    "  \"input\": {\n",
    "    \"messages\": [\n",
    "      {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"<string>\"\n",
    "      }\n",
    "    ],\n",
    "    \"tools\": [],\n",
    "    \"parallel_tool_calls\": true\n",
    "  },\n",
    "  \"preferred_output\": [\n",
    "    {\n",
    "      \"role\": \"assistant\",\n",
    "      \"content\": \"<string>\"\n",
    "    }\n",
    "  ],\n",
    "  \"non_preferred_output\": [\n",
    "    {\n",
    "      \"role\": \"assistant\",\n",
    "      \"content\": \"<string>\"\n",
    "    }\n",
    "  ]\n",
    "}\n",
    "\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1abda026",
   "metadata": {},
   "outputs": [],
   "source": [
    "rendered_samples = t0.experimental_render_samples(\n",
    "    stored_samples=inferences, variants={FUNCTION_NAME: TEMPLATE_VARIANT_NAME}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e157434b",
   "metadata": {},
   "source": [
    "Split data into training and validation sets for fine-tuning\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6a5546d",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.shuffle(rendered_samples)\n",
    "train_samples = rendered_samples[: int(len(rendered_samples) * (1 - VAL_FRACTION))]\n",
    "val_samples = rendered_samples[int(len(rendered_samples) * (1 - VAL_FRACTION)) :]\n",
    "\n",
    "print(f\"Training set size: {len(train_samples)}\")\n",
    "print(f\"Validation set size: {len(val_samples)}\")\n",
    "print(f\"Actual validation fraction: {len(val_samples) / len(rendered_samples):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a583156d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_output(output: List[ContentBlock]) -> Dict[str, Any]:\n",
    "    content = []\n",
    "    tool_calls = []\n",
    "\n",
    "    for block in output:\n",
    "        if block.type == \"text\":\n",
    "            content.append({\"type\": \"text\", \"text\": block.text})\n",
    "        elif block.type == \"thought\":\n",
    "            content.append({\"type\": \"text\", \"text\": f\"<think>{block.text}</think>\"})\n",
    "        elif block.type == \"tool_call\":\n",
    "            tool_calls.append(\n",
    "                {\n",
    "                    \"function\": {\n",
    "                        \"arguments\": json.dumps(block.arguments),\n",
    "                        \"name\": block.name,\n",
    "                    },\n",
    "                    \"id\": block.id,\n",
    "                    \"type\": \"function\",\n",
    "                }\n",
    "            )\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported content type: {block.type}\")\n",
    "\n",
    "    output_message: Dict[str, Any] = {\"role\": \"assistant\"}\n",
    "    if content:\n",
    "        output_message[\"content\"] = content\n",
    "    if tool_calls:\n",
    "        output_message[\"tool_calls\"] = tool_calls\n",
    "\n",
    "    return output_message\n",
    "\n",
    "\n",
    "def sample_to_openai_messages(sample: RenderedSample) -> Dict[str, Any]:\n",
    "    result = {\n",
    "        \"input\": {\"messages\": [], \"tools\": [], \"parallel_tool_calls\": True},\n",
    "        \"preferred_output\": [],\n",
    "        \"non_preferred_output\": [],\n",
    "    }\n",
    "\n",
    "    if sample.input.system:\n",
    "        result[\"input\"][\"messages\"].append({\"role\": \"system\", \"content\": sample.input.system})\n",
    "    for message in sample.input.messages:\n",
    "        content = []\n",
    "        for part in message.content:\n",
    "            if part.type == \"text\":\n",
    "                content.append(part.text)\n",
    "            else:\n",
    "                raise ValueError(f\"Unsupported content type: {part.type}\")\n",
    "        if len(content) != 1:\n",
    "            raise ValueError(f\"Expected exactly one content part for message {message}, got {len(content)}\")\n",
    "        result[\"input\"][\"messages\"].append({\"role\": message.role, \"content\": content[0]})\n",
    "\n",
    "    result[\"preferred_output\"].append(prepare_output(sample.output))\n",
    "    if len(sample.dispreferred_outputs) != 1:\n",
    "        raise ValueError(\n",
    "            f\"Expected exactly one dispreferred output for sample {sample}, got {len(sample.dispreferred_outputs)}\"\n",
    "        )\n",
    "    result[\"non_preferred_output\"].append(prepare_output(sample.dispreferred_outputs[0]))\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "def prepare_samples(samples: List[RenderedSample]) -> List[Dict[str, Any]]:\n",
    "    return [sample_to_openai_messages(sample) for sample in samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fcf0566",
   "metadata": {},
   "outputs": [],
   "source": [
    "prepared_train_samples = prepare_samples(train_samples)\n",
    "prepared_val_samples = prepare_samples(val_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a8dac3e",
   "metadata": {},
   "source": [
    "Upload the prepared datasets to OpenAI.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b95ae94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def upload_dataset_to_openai(samples, openai_client) -> str:\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".jsonl\", delete=False) as f:\n",
    "        for item in samples:\n",
    "            json.dump(item, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "\n",
    "        print(f\"File persisted on path [{f.name}]\")\n",
    "\n",
    "        with open(f.name, \"rb\") as file:\n",
    "            file_object = openai_client.files.create(file=file, purpose=\"fine-tune\")\n",
    "\n",
    "        return file_object.id\n",
    "\n",
    "\n",
    "openai_client = openai.OpenAI()\n",
    "\n",
    "dpo_fine_tuning_object_id = upload_dataset_to_openai(prepared_train_samples, openai_client)\n",
    "val_file_object_id = upload_dataset_to_openai(prepared_val_samples, openai_client)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73fba2d1",
   "metadata": {},
   "source": [
    "Launch the fine-tuning job and wait for it to complete.\n",
    "\n",
    "NOTE : This step takes a while and you can monitor the progress and estimated completion time using OpenAI's fine-tuning [dashboard](https://platform.openai.com/finetune/)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa877b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "fine_tuning_job = openai_client.fine_tuning.jobs.create(\n",
    "    training_file=dpo_fine_tuning_object_id,\n",
    "    validation_file=val_file_object_id,\n",
    "    model=MODEL_NAME,\n",
    "    method={\n",
    "        \"type\": \"dpo\",\n",
    "        \"dpo\": {\n",
    "            \"hyperparameters\": {\"beta\": 0.2},\n",
    "        },\n",
    "    },\n",
    ")\n",
    "\n",
    "while True:\n",
    "    clear_output(wait=True)\n",
    "\n",
    "    try:\n",
    "        job_status = openai_client.fine_tuning.jobs.retrieve(fine_tuning_job.id)\n",
    "        pprint(job_status.to_dict())\n",
    "        if job_status.status in (\"succeeded\", \"failed\", \"cancelled\"):\n",
    "            break\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {e}\")\n",
    "\n",
    "    time.sleep(10)\n",
    "\n",
    "print(f\"The fine-tuning job has compeleted with result {job_status.status}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d34c62a9",
   "metadata": {},
   "source": [
    "Once the fine-tuning job is complete, you can add the fine-tuned model to your config file.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "497f2111",
   "metadata": {},
   "outputs": [],
   "source": [
    "fine_tuned_model = job_status.fine_tuned_model\n",
    "model_config = {\n",
    "    \"models\": {\n",
    "        fine_tuned_model: {\n",
    "            \"routing\": [\"openai\"],\n",
    "            \"providers\": {\"openai\": {\"type\": \"openai\", \"model_name\": fine_tuned_model}},\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "print(toml.dumps(model_config))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58b70ee0",
   "metadata": {},
   "source": [
    "You'll need to add this model to a new variant you define in your config.\n",
    "\n",
    "Then, you're all set!\n",
    "\n",
    "You can change the weight to enable a gradual rollout of the new model.\n",
    "\n",
    "You might also add other parameters (e.g. max_tokens, temperature) to the variant section in the config file.\n"
   ]
  }
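  {
   "cell_type": "markdown",
   "id": "f7e1c2a9",
   "metadata": {},
   "source": [
    "As a rough sketch (assuming a standard `chat_completion` variant), the cell below prints a variant block you could adapt: the variant name `gpt_4o_dpo` and the `weight` value are placeholders, and you may also need to carry over templates or other settings from your existing variant.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3d5b7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of a new variant that routes some traffic to the fine-tuned model.\n",
    "# Adjust the variant name and weight to fit your rollout plan.\n",
    "variant_config = {\n",
    "    \"functions\": {\n",
    "        FUNCTION_NAME: {\n",
    "            \"variants\": {\n",
    "                \"gpt_4o_dpo\": {\n",
    "                    \"type\": \"chat_completion\",\n",
    "                    \"model\": fine_tuned_model,\n",
    "                    \"weight\": 0.1,\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "print(toml.dumps(variant_config))"
   ]
  }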
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,py:percent",
   "main_language": "python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
